"""Benchmark service for handling web-based benchmark execution."""

import hashlib
import json
import threading
import time
from datetime import UTC, datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from loguru import logger

from ...api.research_functions import quick_summary
from ...database.models.benchmark import (
    BenchmarkResult,
    BenchmarkRun,
    BenchmarkStatus,
    DatasetType,
)
from ...web.services.socket_service import SocketIOService
from ..datasets import load_dataset
from ..graders import extract_answer_from_response, grade_single_result
from ..runners import format_query


class BenchmarkTaskStatus(Enum):
    """Status values for benchmark tasks in the queue tracker."""

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class BenchmarkQueueTracker:
    """Simple in-memory tracker for benchmark queue status.

    This replaces the removed memory_queue functionality for benchmarks.
    Since benchmarks are temporary and don't need persistence,
    this simple in-memory solution is sufficient.

    Thread-safe for concurrent access from multiple benchmark threads.
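
    Example (illustrative usage; the task ID and username are made up):

        tracker = BenchmarkQueueTracker()
        tracker.add_task("benchmark_7_1700000000", "alice")
        tracker.update_task_status(
            "benchmark_7_1700000000", BenchmarkTaskStatus.PROCESSING
        )
        tracker.get_task_status("benchmark_7_1700000000")["status"]
        # -> "processing"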
    """

    def __init__(self):
        self.tasks = {}
        self._lock = threading.Lock()

    def add_task(
        self, task_id: str, username: str, task_type: str = "benchmark"
    ):
        """Add a new task to tracking.

        Also performs opportunistic cleanup of old completed tasks.
        """
        # Clean up old tasks before taking the lock: cleanup_completed_tasks()
        # acquires the non-reentrant lock itself, so it must not be called
        # while this thread already holds it.
        self.cleanup_completed_tasks()

        with self._lock:
            self.tasks[task_id] = {
                "username": username,
                "task_type": task_type,
                "status": BenchmarkTaskStatus.QUEUED.value,
                "created_at": datetime.now(UTC),
            }

    def update_task_status(self, task_id: str, status: BenchmarkTaskStatus):
        """Update the status of a task."""
        with self._lock:
            if task_id in self.tasks:
                self.tasks[task_id]["status"] = status.value
                self.tasks[task_id]["updated_at"] = datetime.now(UTC)
            else:
                logger.warning(
                    f"Attempted to update status for non-existent task: {task_id}"
                )

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Get the current status of a task."""
        with self._lock:
            return self.tasks.get(task_id)

    def remove_task(self, task_id: str):
        """Remove a task from tracking."""
        with self._lock:
            self.tasks.pop(task_id, None)

    def cleanup_completed_tasks(self, max_age_seconds: int = 3600):
        """Remove completed tasks older than max_age_seconds.

        Args:
            max_age_seconds: Maximum age in seconds for completed tasks (default 1 hour)
        """
        with self._lock:
            now = datetime.now(UTC)
            to_remove = []
            for task_id, task_data in self.tasks.items():
                # Only cleanup completed, failed, or cancelled tasks
                if task_data["status"] in [
                    BenchmarkTaskStatus.COMPLETED.value,
                    BenchmarkTaskStatus.FAILED.value,
                    BenchmarkTaskStatus.CANCELLED.value,
                ]:
                    # Check if task has updated_at timestamp
                    updated_at = task_data.get(
                        "updated_at", task_data.get("created_at")
                    )
                    if updated_at:
                        age = (now - updated_at).total_seconds()
                        if age > max_age_seconds:
                            to_remove.append(task_id)

            for task_id in to_remove:
                self.tasks.pop(task_id, None)
                logger.debug(f"Cleaned up old task: {task_id}")

            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} old benchmark tasks")


class BenchmarkService:
    """Service for managing benchmark runs through the web interface."""

    def __init__(self, socket_service=None):
        self.active_runs: Dict[int, Dict] = {}
        self.socket_service = socket_service or self._get_socket_service()
        self.rate_limit_detected: Dict[
            int, bool
        ] = {}  # Track rate limiting per benchmark run
        self.queue_tracker = BenchmarkQueueTracker()  # Initialize queue tracker

    def _get_socket_service(self):
        """Get socket service instance, handling cases where Flask app is not available."""
        try:
            return SocketIOService()
        except Exception:
            # Return a mock socket service for testing/standalone use.
            # It must expose the same emit methods the real service is
            # called with in this module.
            class MockSocketService:
                def emit_to_room(self, room, event, data):
                    pass

                def emit_to_subscribers(self, event, target_id, data):
                    pass

            return MockSocketService()

    def generate_config_hash(self, search_config: Dict[str, Any]) -> str:
        """Generate a hash for search configuration compatibility checking."""
        relevant_params = {
            "iterations": search_config.get("iterations"),
            "questions_per_iteration": search_config.get(
                "questions_per_iteration"
            ),
            "search_tool": search_config.get("search_tool"),
            "search_strategy": search_config.get("search_strategy"),
            "model_name": search_config.get("model_name"),
            "provider": search_config.get("provider"),
        }
        # Remove None values
        relevant_params = {
            k: v for k, v in relevant_params.items() if v is not None
        }
        config_str = json.dumps(relevant_params, sort_keys=True)
        return hashlib.md5(  # DevSkim: ignore DS126858
            config_str.encode(), usedforsecurity=False
        ).hexdigest()[:8]

    def generate_query_hash(self, question: str, dataset_type: str) -> str:
        """Generate a hash for a query to enable deduplication."""
        query_content = f"{question.strip()}|{dataset_type.lower()}"
        return hashlib.md5(  # DevSkim: ignore DS126858
            query_content.encode(), usedforsecurity=False
        ).hexdigest()

    def create_benchmark_run(
        self,
        run_name: Optional[str],
        search_config: Dict[str, Any],
        evaluation_config: Dict[str, Any],
        datasets_config: Dict[str, Dict],
        username: str = None,
        user_password: str = None,
    ) -> int:
        """Create a new benchmark run in the database."""
        from ...database.session_context import get_user_db_session

        with get_user_db_session(username, user_password) as session:
            try:
                config_hash = self.generate_config_hash(search_config)

                # Calculate total examples
                total_examples = sum(
                    config.get("count", 0)
                    for config in datasets_config.values()
                )

                benchmark_run = BenchmarkRun(
                    run_name=run_name,
                    config_hash=config_hash,
                    query_hash_list=[],  # Will be populated as we process
                    search_config=search_config,
                    evaluation_config=evaluation_config,
                    datasets_config=datasets_config,
                    total_examples=total_examples,
                    status=BenchmarkStatus.PENDING,
                )

                session.add(benchmark_run)
                session.commit()

                logger.info(
                    f"Created benchmark run {benchmark_run.id} with config hash {config_hash}"
                )
                return benchmark_run.id

            except Exception:
                session.rollback()
                logger.exception("Error creating benchmark run")
                raise

    def get_existing_results(
        self, config_hash: str, username: str = None, user_password: str = None
    ) -> Dict[str, Dict]:
        """Get existing results with compatible configuration."""
        from ...database.session_context import get_user_db_session

        with get_user_db_session(username, user_password) as session:
            try:
                # Find compatible runs
                compatible_runs = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.config_hash == config_hash)
                    .filter(BenchmarkRun.status == BenchmarkStatus.COMPLETED)
                    .all()
                )

                existing_results = {}
                for run in compatible_runs:
                    results = (
                        session.query(BenchmarkResult)
                        .filter(BenchmarkResult.benchmark_run_id == run.id)
                        .filter(
                            BenchmarkResult.is_correct.isnot(None)
                        )  # Only completed evaluations
                        .all()
                    )

                    for result in results:
                        existing_results[result.query_hash] = {
                            "id": result.example_id,
                            "dataset_type": result.dataset_type.value,
                            "problem": result.question,
                            "correct_answer": result.correct_answer,
                            "response": result.response,
                            "extracted_answer": result.extracted_answer,
                            "confidence": result.confidence,
                            "processing_time": result.processing_time,
                            "sources": result.sources,
                            "is_correct": result.is_correct,
                            "graded_confidence": result.graded_confidence,
                            "grader_response": result.grader_response,
                            "query_hash": result.query_hash,
                        }

                logger.info(
                    f"Found {len(existing_results)} existing results for config hash {config_hash}"
                )
                return existing_results

            except Exception:
                logger.exception("Error loading existing results")
                return {}

    def start_benchmark(
        self,
        benchmark_run_id: int,
        username: str = None,
        user_password: str = None,
    ) -> bool:
        """Start a benchmark run in a background thread."""
        from ...database.session_context import get_user_db_session

        try:
            # Get all data from the database in the main thread
            # This avoids database access from the background thread
            with get_user_db_session(username, user_password) as session:
                # Get benchmark run details
                benchmark_run = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.id == benchmark_run_id)
                    .first()
                )
                if not benchmark_run:
                    raise ValueError(
                        f"Benchmark run {benchmark_run_id} not found"
                    )

                # Create settings snapshot for thread safety
                from local_deep_research.settings import SettingsManager

                settings_manager = SettingsManager(session)
                settings_snapshot = settings_manager.get_all_settings()

                # Get user password for metrics tracking in the background
                # thread. Only fall back to the session password store when no
                # password was passed in explicitly.
                from flask import session as flask_session
                from ...database.session_passwords import session_password_store

                if not user_password:
                    session_id = flask_session.get("session_id")
                    if session_id and username:
                        user_password = (
                            session_password_store.get_session_password(
                                username, session_id
                            )
                        )
                        if not user_password:
                            logger.warning(
                                f"No password found for user {username} in current session"
                            )

                # Extract all data we need
                benchmark_data = {
                    "benchmark_run_id": benchmark_run_id,
                    "username": username or "benchmark_user",
                    "user_password": user_password,  # Add password for metrics tracking
                    "config_hash": benchmark_run.config_hash,
                    "datasets_config": benchmark_run.datasets_config,
                    "search_config": benchmark_run.search_config,
                    "evaluation_config": benchmark_run.evaluation_config,
                    "existing_results": self.get_existing_results(
                        benchmark_run.config_hash, username, user_password
                    ),
                    "settings_snapshot": settings_snapshot,  # Add settings snapshot
                }

                # Update status in database
                benchmark_run.status = BenchmarkStatus.IN_PROGRESS
                benchmark_run.start_time = datetime.now(UTC)
                session.commit()

            # Store data in memory for the thread
            self.active_runs[benchmark_run_id] = {
                "data": benchmark_data,
                "start_time": datetime.now(UTC),
                "status": "running",
                "results": [],
            }

            # Start background thread
            thread = threading.Thread(
                target=self._run_benchmark_thread,
                args=(benchmark_run_id,),
                daemon=True,
            )
            thread.start()

            self.active_runs[benchmark_run_id]["thread"] = thread

            logger.info(f"Started benchmark run {benchmark_run_id}")
            return True

        except Exception as e:
            logger.exception(f"Error starting benchmark {benchmark_run_id}")
            # Update status using user database
            with get_user_db_session(username, user_password) as session:
                benchmark_run = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.id == benchmark_run_id)
                    .first()
                )
                if benchmark_run:
                    benchmark_run.status = BenchmarkStatus.FAILED
                    benchmark_run.error_message = str(e)
                    session.commit()
            return False

    def _run_benchmark_thread(self, benchmark_run_id: int):
        """Main benchmark execution thread."""
        # IMPORTANT: This runs in a background thread, so we cannot access the user database
        # Using in-memory queue tracker for benchmark status tracking

        task_id = None
        try:
            # Get the benchmark data that start_benchmark() captured and stored
            # in memory for this run (no database access from this thread)
            benchmark_data = self.active_runs.get(benchmark_run_id, {}).get(
                "data"
            )

            if not benchmark_data:
                raise ValueError(
                    f"Benchmark data for run {benchmark_run_id} not found"
                )

            # Set up settings context for thread-local access
            settings_snapshot = benchmark_data.get("settings_snapshot", {})
            username = benchmark_data.get("username", "benchmark_user")

            # Create a settings context that threads can use
            class SettingsContext:
                def __init__(self, snapshot, username):
                    self.snapshot = snapshot or {}
                    self.username = username
                    # Extract values from setting objects if needed
                    self.values = {}
                    for key, setting in self.snapshot.items():
                        if isinstance(setting, dict) and "value" in setting:
                            # It's a full setting object, extract the value
                            self.values[key] = setting["value"]
                        else:
                            # It's already just a value
                            self.values[key] = setting

                def get_setting(self, key, default=None):
                    """Get setting from snapshot only - no database access in threads"""
                    if key in self.values:
                        return self.values[key]
                    # No fallback to database - threads must use snapshot only
                    logger.warning(
                        f"Setting '{key}' not found in snapshot, using default: {default}"
                    )
                    return default

            settings_context = SettingsContext(settings_snapshot, username)
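            # Illustrative note: the snapshot may store either full setting
            # objects or bare values; both of these hypothetical entries
            # resolve to 0.7:
            #     {"llm.temperature": {"value": 0.7}}
            #     {"llm.temperature": 0.7}
            # Anything missing from the snapshot falls back to the caller's
            # default, since threads never touch the settings database.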

            # Set the context in thread-local storage
            from ...config.thread_settings import set_settings_context

            set_settings_context(settings_context)

            # Extract all the data we need
            datasets_config = benchmark_data["datasets_config"]
            search_config = benchmark_data["search_config"]
            evaluation_config = benchmark_data["evaluation_config"]
            existing_results = benchmark_data.get("existing_results", {})

            # Create task queue
            task_queue = self._create_task_queue(
                datasets_config,
                existing_results,
                benchmark_run_id,
            )

            # Calculate totals
            total_examples = len(task_queue) + len(existing_results)
            completed_examples = len(existing_results)

            # Initialize task tracking
            task_id = f"benchmark_{benchmark_run_id}_{int(datetime.now(UTC).timestamp())}"
            username = benchmark_data.get("username", "benchmark_user")
            self.queue_tracker.add_task(task_id, username, "benchmark")
            self.queue_tracker.update_task_status(
                task_id, BenchmarkTaskStatus.PROCESSING
            )

            # Track progress in memory
            progress_info = {
                "total_examples": total_examples,
                "completed_examples": completed_examples,
                "failed_examples": 0,
                "start_time": datetime.now(UTC),
            }

            # Process tasks
            logger.info(
                f"Benchmark {benchmark_run_id} starting to process {len(task_queue)} tasks"
            )
            for i, task in enumerate(task_queue):
                # Check if benchmark has been cancelled
                if (
                    benchmark_run_id in self.active_runs
                    and self.active_runs[benchmark_run_id].get("status")
                    == "cancelled"
                ):
                    logger.info(
                        f"Benchmark {benchmark_run_id} was cancelled, stopping processing"
                    )
                    break

                logger.info(
                    f"Benchmark {benchmark_run_id} processing task {i + 1}/{len(task_queue)}"
                )
                try:
                    # Add username and password to task for metrics tracking
                    task["username"] = benchmark_data.get("username")
                    task["user_password"] = benchmark_data.get("user_password")

                    # Process single task
                    result = self._process_benchmark_task(
                        task,
                        search_config,
                        evaluation_config,
                    )

                    # Store result in memory for now (will be saved later)
                    if "results" not in self.active_runs[benchmark_run_id]:
                        self.active_runs[benchmark_run_id]["results"] = []
                    self.active_runs[benchmark_run_id]["results"].append(result)

                    # Update progress
                    progress_info["completed_examples"] += 1

                    logger.info(
                        f"Benchmark {benchmark_run_id} task {i + 1}/{len(task_queue)} completed successfully. "
                        f"Progress: {progress_info['completed_examples']}/{progress_info['total_examples']} total examples"
                    )

                    # Send real-time update
                    self._send_progress_update(
                        benchmark_run_id,
                        progress_info["completed_examples"],
                        progress_info["total_examples"],
                    )

                except Exception as e:
                    logger.exception(f"Error processing task {i}: {str(e)}")
                    progress_info["failed_examples"] += 1
                    logger.info(
                        f"Benchmark {benchmark_run_id} task {i + 1}/{len(task_queue)} failed. "
                        f"Total failed: {progress_info['failed_examples']}"
                    )

                    # Check if this is a rate limiting error
                    error_str = str(e).lower()
                    if (
                        "403" in error_str
                        or "rate limit" in error_str
                        or "forbidden" in error_str
                    ):
                        self.rate_limit_detected[benchmark_run_id] = True
                        # Send rate limit warning via WebSocket
                        self.socket_service.emit_to_subscribers(
                            "research_progress",
                            benchmark_run_id,
                            {
                                "rate_limit_detected": True,
                                "message": "SearXNG rate limiting detected",
                            },
                        )

            # Mark as completed in memory tracker
            progress_info["end_time"] = datetime.now(UTC)

            # Check if benchmark was cancelled
            was_cancelled = (
                benchmark_run_id in self.active_runs
                and self.active_runs[benchmark_run_id].get("status")
                == "cancelled"
            )

            if was_cancelled:
                status = BenchmarkStatus.CANCELLED
                message = "Benchmark cancelled by user"
                if task_id:
                    self.queue_tracker.update_task_status(
                        task_id, BenchmarkTaskStatus.CANCELLED
                    )
            else:
                status = BenchmarkStatus.COMPLETED
                message = "Benchmark completed successfully"
                if task_id:
                    self.queue_tracker.update_task_status(
                        task_id, BenchmarkTaskStatus.COMPLETED
                    )

            # Store completion info for later database update
            self.active_runs[benchmark_run_id]["completion_info"] = {
                "status": status,
                "end_time": progress_info["end_time"],
                "completed_examples": progress_info["completed_examples"],
                "failed_examples": progress_info["failed_examples"],
            }

            # Send completion notification
            self.socket_service.emit_to_subscribers(
                "research_progress",
                benchmark_run_id,
                {
                    "status": "cancelled" if was_cancelled else "completed",
                    "message": message,
                    "progress": (
                        progress_info["completed_examples"]
                        / progress_info["total_examples"]
                        * 100
                    )
                    if progress_info["total_examples"] > 0
                    else 0,
                    "benchmark_run_id": benchmark_run_id,
                },
            )

        except Exception as e:
            logger.exception(f"Benchmark run {benchmark_run_id} failed")
            # Update task status if we have a task_id
            if task_id:
                self.queue_tracker.update_task_status(
                    task_id, BenchmarkTaskStatus.FAILED
                )
            # Store failure info for later database update
            if benchmark_run_id in self.active_runs:
                self.active_runs[benchmark_run_id]["completion_info"] = {
                    "status": BenchmarkStatus.FAILED,
                    "error_message": str(e),
                }
        finally:
            # Clean up active run tracking
            if benchmark_run_id in self.active_runs:
                # Mark that thread is done but keep data for database update
                self.active_runs[benchmark_run_id]["thread_complete"] = True

                # Try to save results to database immediately if possible
                self._sync_results_to_database(benchmark_run_id)

    def _create_task_queue(
        self,
        datasets_config: Dict,
        existing_results: Dict,
        benchmark_run_id: int,
    ) -> List[Dict]:
        """Create list of tasks to process, excluding existing results."""
        tasks = []

        for dataset_name, config in datasets_config.items():
            if config.get("count", 0) > 0:
                dataset = load_dataset(
                    dataset_type=dataset_name,
                    num_examples=config["count"],
                    seed=None,
                )

                for i, example in enumerate(dataset):
                    # Both supported dataset types (simpleqa and browsecomp)
                    # expose the question as "problem" and the gold answer as
                    # "answer", so extraction is the same for each.
                    question = example.get("problem", "")
                    correct_answer = example.get("answer", "")

                    # Generate query hash
                    query_hash = self.generate_query_hash(
                        question, dataset_name
                    )

                    # Skip if already processed
                    if query_hash in existing_results:
                        continue

                    tasks.append(
                        {
                            "benchmark_run_id": benchmark_run_id,
                            "example_id": example.get("id", f"example_{i}"),
                            "dataset_type": dataset_name,
                            "question": question,
                            "correct_answer": correct_answer,
                            "query_hash": query_hash,
                            "task_index": len(tasks),
                        }
                    )

        return tasks

    def _process_benchmark_task(
        self, task: Dict, search_config: Dict, evaluation_config: Dict
    ) -> Dict:
        """Process a single benchmark task."""
        try:
            logger.info(
                f"Starting benchmark task {task['task_index'] + 1}: "
                f"example_id={task['example_id']}, dataset={task['dataset_type']}, "
                f"question_preview='{task['question'][:100]}...'"
            )

            # Get settings context from thread-local storage
            from ...config.thread_settings import get_settings_context

            settings_context = get_settings_context()

            # Generate a unique tracking ID for this benchmark task
            import uuid

            tracking_id = str(uuid.uuid4())
            logger.info(
                f"Task {task['example_id']} assigned tracking_id: {tracking_id}"
            )

            # Format query
            formatted_query = format_query(
                task["question"], task["dataset_type"]
            )
            logger.info(
                f"Task {task['example_id']} formatted query: '{formatted_query[:150]}...'"
            )

            # Run research with progress callback for WebSocket updates
            start_time = time.time()
            logger.info(f"Task {task['example_id']} starting research phase...")

            def benchmark_progress_callback(
                status: str, progress: int, data: dict
            ):
                """Progress callback to emit detailed research progress via WebSocket"""
                try:
                    timestamp = datetime.now(UTC).isoformat()

                    # Create research-compatible log entry
                    log_entry = {
                        "time": timestamp,
                        "message": f"Example {task['example_id']}: {status}",
                        "progress": progress,
                        "metadata": {
                            "phase": data.get("phase", "benchmark_processing"),
                            "type": data.get("type", "info"),
                            "example_id": task["example_id"],
                            "benchmark_run_id": task["benchmark_run_id"],
                            **data,  # Include all other data
                        },
                    }

                    # Determine log type based on status/message content
                    if (
                        "complete" in status.lower()
                        or "finished" in status.lower()
                    ):
                        log_entry["metadata"]["type"] = "milestone"
                    elif (
                        "error" in status.lower() or "failed" in status.lower()
                    ):
                        log_entry["metadata"]["type"] = "error"
                    elif (
                        "starting" in status.lower()
                        or "begin" in status.lower()
                    ):
                        log_entry["metadata"]["type"] = "milestone"

                    # Create progress data in research format
                    progress_data = {
                        "progress": progress,
                        "message": status,
                        "status": "in_progress",
                        "log_entry": log_entry,
                        "progress_log": json.dumps(
                            [log_entry]
                        ),  # Array format expected by socket.js
                    }

                    # Emit using research_progress format that the UI expects
                    self.socket_service.emit_to_subscribers(
                        "research_progress",
                        task["benchmark_run_id"],
                        progress_data,
                    )

                except Exception:
                    logger.exception("Error sending benchmark progress update")

            # Get user password from task data
            user_password = task.get("user_password")

            search_result = quick_summary(
                query=formatted_query,
                research_id=tracking_id,  # Pass the tracking ID
                iterations=search_config.get("iterations", 8),
                questions_per_iteration=search_config.get(
                    "questions_per_iteration", 5
                ),
                search_tool=search_config.get("search_tool", "searxng"),
                search_strategy=search_config.get(
                    "search_strategy", "focused_iteration"
                ),
                progress_callback=benchmark_progress_callback,
                model_name=search_config.get("model_name"),
                provider=search_config.get("provider"),
                temperature=search_config.get("temperature", 0.7),
                openai_endpoint_url=search_config.get("openai_endpoint_url"),
                settings_snapshot=settings_context.snapshot,  # Pass settings snapshot for thread safety
                username=task.get("username"),  # Pass username
                user_password=user_password,  # Pass password for metrics tracking
            )
            processing_time = time.time() - start_time
            logger.info(
                f"Task {task['example_id']} research completed in {processing_time:.2f}s, "
                f"model={search_config.get('model_name')}, provider={search_config.get('provider')}"
            )

            # Extract answer
            response = search_result.get("summary", "")
            logger.info(
                f"Task {task['example_id']} response length: {len(response)} chars"
            )

            extracted_data = extract_answer_from_response(
                response, task["dataset_type"]
            )
            extracted_answer = (
                extracted_data.get("extracted_answer", "")
                if isinstance(extracted_data, dict)
                else str(extracted_data)
            )
            logger.info(
                f"Task {task['example_id']} extracted answer: '{extracted_answer[:100]}...'"
            )

            # Extract sources - handle both direct sources and all_links_of_system
            sources = search_result.get("sources", [])
            if not sources and "all_links_of_system" in search_result:
                sources = search_result.get("all_links_of_system", [])

            # Log for debugging
            logger.debug(f"Search result keys: {list(search_result.keys())}")
            logger.debug(f"Sources found: {len(sources)} items")

            # Prepare result
            result = {
                **task,
                "response": response,
                "extracted_answer": extracted_answer,
                "confidence": str(
                    extracted_data.get("confidence", "100")
                    if isinstance(extracted_data, dict)
                    else "100"
                ),
                "processing_time": processing_time,
                "sources": json.dumps(sources),  # Convert to JSON string
                "completed_at": datetime.now(UTC),
                "research_id": tracking_id,  # Store the UUID in the research_id field
            }

            # Evaluate result - requires proper grading model
            try:
                logger.info(f"Task {task['example_id']} starting evaluation...")
                eval_start_time = time.time()

                # Always attempt evaluation regardless of provider; modern
                # local models (e.g. via Ollama) are capable of acting as
                # graders.
                result_data = {
                    "id": task["example_id"],
                    "problem": task["question"],
                    "correct_answer": task["correct_answer"],
                    "response": response,
                    "extracted_answer": extracted_answer,
                }

                eval_result = grade_single_result(
                    result_data,
                    task["dataset_type"],
                    evaluation_config,
                    settings_context.snapshot,
                )
                eval_time = time.time() - eval_start_time
                logger.info(
                    f"Task {task['example_id']} evaluation completed in {eval_time:.2f}s"
                )
                if eval_result and not eval_result.get("grading_error"):
                    result.update(
                        {
                            "is_correct": eval_result.get("is_correct", False),
                            "graded_confidence": eval_result.get(
                                "graded_confidence", "0"
                            ),
                            "grader_response": eval_result.get(
                                "grader_response", ""
                            ),
                        }
                    )
                else:
                    error_msg = (
                        eval_result.get(
                            "grading_error", "Unknown evaluation error"
                        )
                        if eval_result
                        else "No evaluation results returned"
                    )
                    result.update(
                        {
                            "is_correct": None,
                            "graded_confidence": "0",
                            "grader_response": f"Evaluation failed: {error_msg}",
                            "evaluation_error": error_msg,
                        }
                    )

            except Exception as e:
                logger.exception("Evaluation error")
                result.update(
                    {
                        "is_correct": None,
                        "graded_confidence": "0",
                        "grader_response": f"Evaluation failed: {e!s}",
                        "evaluation_error": str(e),
                    }
                )

            return result

        except Exception as e:
            logger.exception("Research error")
            return {
                **task,
                "research_error": str(e),
                "completed_at": datetime.now(UTC),
            }

    def sync_pending_results(self, benchmark_run_id: int, username: str = None):
        """Sync any pending results to database. Can be called from main thread."""
        if benchmark_run_id not in self.active_runs:
            return 0

        run_data = self.active_runs[benchmark_run_id]
        results_to_save = run_data.get("results", [])
        saved_indices = run_data.get("saved_indices", set())

        if not username:
            username = run_data.get("data", {}).get("username")

        user_password = run_data.get("data", {}).get("user_password")

        saved_count = 0
        from ...database.session_context import get_user_db_session
        from ...database.models.benchmark import BenchmarkResult

        try:
            with get_user_db_session(username, user_password) as session:
                # Save any results that haven't been saved yet
                for idx, result in enumerate(results_to_save):
                    if idx not in saved_indices:
                        # Check if this result already exists in the database
                        existing = (
                            session.query(BenchmarkResult)
                            .filter_by(
                                benchmark_run_id=benchmark_run_id,
                                query_hash=result["query_hash"],
                            )
                            .first()
                        )

                        if existing:
                            # Skip if already exists
                            saved_indices.add(idx)
                            continue

                        benchmark_result = BenchmarkResult(
                            benchmark_run_id=benchmark_run_id,
                            example_id=result["example_id"],
                            query_hash=result["query_hash"],
                            dataset_type=DatasetType(result["dataset_type"]),
                            research_id=result.get("research_id"),
                            question=result["question"],
                            correct_answer=result["correct_answer"],
                            response=result.get("response"),
                            extracted_answer=result.get("extracted_answer"),
                            confidence=result.get("confidence"),
                            processing_time=result.get("processing_time"),
                            sources=result.get("sources"),
                            is_correct=result.get("is_correct"),
                            graded_confidence=result.get("graded_confidence"),
                            grader_response=result.get("grader_response"),
                            completed_at=result.get("completed_at"),
                            research_error=result.get("research_error"),
                            evaluation_error=result.get("evaluation_error"),
                            task_index=result.get("task_index"),
                        )
                        session.add(benchmark_result)
                        saved_indices.add(idx)
                        saved_count += 1

                if saved_count > 0:
                    session.commit()
                    run_data["saved_indices"] = saved_indices
                    logger.info(
                        f"Saved {saved_count} new results for benchmark {benchmark_run_id}"
                    )

        except Exception:
            logger.exception(
                f"Error syncing pending results for benchmark {benchmark_run_id}"
            )
            # Roll back the session on error to prevent PendingRollbackError
            try:
                session.rollback()
            except Exception:
                pass

        return saved_count

    def _sync_results_to_database(self, benchmark_run_id: int):
        """Sync benchmark results from memory to database after thread completes."""
        if benchmark_run_id not in self.active_runs:
            return

        run_data = self.active_runs[benchmark_run_id]
        if not run_data.get("thread_complete"):
            return

        username = run_data.get("data", {}).get("username")
        user_password = run_data.get("data", {}).get("user_password")
        from ...database.session_context import get_user_db_session

        try:
            with get_user_db_session(username, user_password) as session:
                # Update benchmark run status
                benchmark_run = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.id == benchmark_run_id)
                    .first()
                )

                if benchmark_run and "completion_info" in run_data:
                    info = run_data["completion_info"]
                    benchmark_run.status = info["status"]
                    benchmark_run.end_time = info.get(
                        "end_time", datetime.now(UTC)
                    )
                    benchmark_run.completed_examples = info.get(
                        "completed_examples", 0
                    )
                    benchmark_run.failed_examples = info.get(
                        "failed_examples", 0
                    )
                    benchmark_run.error_message = info.get("error_message")

                    # Save all results (skip already saved ones)
                    saved_indices = run_data.get("saved_indices", set())
                    for idx, result in enumerate(run_data.get("results", [])):
                        if idx in saved_indices:
                            continue
                        benchmark_result = BenchmarkResult(
                            benchmark_run_id=benchmark_run_id,
                            example_id=result["example_id"],
                            query_hash=result["query_hash"],
                            dataset_type=DatasetType(result["dataset_type"]),
                            research_id=result.get("research_id"),
                            question=result["question"],
                            correct_answer=result["correct_answer"],
                            response=result.get("response"),
                            extracted_answer=result.get("extracted_answer"),
                            confidence=result.get("confidence"),
                            processing_time=result.get("processing_time"),
                            sources=result.get("sources"),
                            is_correct=result.get("is_correct"),
                            graded_confidence=result.get("graded_confidence"),
                            grader_response=result.get("grader_response"),
                            completed_at=result.get("completed_at"),
                            research_error=result.get("research_error"),
                            evaluation_error=result.get("evaluation_error"),
                            task_index=result.get("task_index"),
                        )
                        session.add(benchmark_result)

                    # Calculate final accuracy
                    if benchmark_run.status == BenchmarkStatus.COMPLETED:
                        correct_results = [
                            r
                            for r in run_data.get("results", [])
                            if r.get("is_correct")
                        ]
                        evaluated_results = [
                            r
                            for r in run_data.get("results", [])
                            if r.get("is_correct") is not None
                        ]

                        if evaluated_results:
                            benchmark_run.overall_accuracy = (
                                len(correct_results) / len(evaluated_results)
                            ) * 100

                            # Calculate processing rate
                            total_time = sum(
                                r.get("processing_time", 0)
                                for r in evaluated_results
                            )
                            if total_time > 0:
                                benchmark_run.processing_rate = len(
                                    evaluated_results
                                ) / (total_time / 60)

                    session.commit()
                    logger.info(
                        f"Successfully synced results for benchmark {benchmark_run_id}"
                    )

            # Clean up memory
            del self.active_runs[benchmark_run_id]

        except Exception:
            logger.exception("Error syncing benchmark results to database")

    def _send_progress_update(
        self, benchmark_run_id: int, completed: int, total: int
    ):
        """Send real-time progress update via websocket."""
        try:
            percentage = (completed / total * 100) if total > 0 else 0

            # Create log entry for milestone progress
            log_entry = {
                "time": datetime.now(UTC).isoformat(),
                "message": f"Completed {completed}/{total} examples ({percentage:.1f}%)",
                "progress": percentage,
                "metadata": {
                    "phase": "benchmark_progress",
                    "type": "milestone",
                    "completed": completed,
                    "total": total,
                    "benchmark_run_id": benchmark_run_id,
                },
            }

            progress_data = {
                "status": "in_progress",
                "message": f"Processing examples: {completed}/{total}",
                "progress": percentage,
                "completed": completed,
                "total": total,
                "benchmark_run_id": benchmark_run_id,
                "log_entry": log_entry,
                "progress_log": json.dumps([log_entry]),
            }

            self.socket_service.emit_to_subscribers(
                "research_progress", benchmark_run_id, progress_data
            )

        except Exception:
            logger.exception("Error sending progress update")

    def _calculate_final_accuracy(
        self, benchmark_run_id: int, username: str = None
    ):
        """Calculate and save final accuracy metrics."""
        from ...database.session_context import get_user_db_session

        with get_user_db_session(username) as session:
            try:
                # Get all results for this run
                results = (
                    session.query(BenchmarkResult)
                    .filter(
                        BenchmarkResult.benchmark_run_id == benchmark_run_id
                    )
                    .filter(BenchmarkResult.is_correct.isnot(None))
                    .all()
                )

                if results:
                    correct_count = sum(1 for r in results if r.is_correct)
                    overall_accuracy = (correct_count / len(results)) * 100

                    # Calculate processing rate
                    total_time = sum(r.processing_time or 0 for r in results)
                    processing_rate = (
                        (len(results) / (total_time / 60))
                        if total_time > 0
                        else 0
                    )
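                    # Rate is in examples per minute, e.g. 20 evaluated
                    # results over 600 s of processing time
                    # -> 20 / (600 / 60) = 2.0 examples per minute.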

                    # Update benchmark run
                    benchmark_run = (
                        session.query(BenchmarkRun)
                        .filter(BenchmarkRun.id == benchmark_run_id)
                        .first()
                    )
                    if benchmark_run:
                        benchmark_run.overall_accuracy = overall_accuracy
                        benchmark_run.processing_rate = processing_rate
                        session.commit()

            except Exception:
                logger.exception("Error calculating final accuracy")

    def update_benchmark_status(
        self,
        benchmark_run_id: int,
        status: BenchmarkStatus,
        error_message: str = None,
        username: str = None,
    ):
        """Update benchmark run status."""
        from ...database.session_context import get_user_db_session

        with get_user_db_session(username) as session:
            try:
                benchmark_run = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.id == benchmark_run_id)
                    .first()
                )
                if benchmark_run:
                    benchmark_run.status = status
                    benchmark_run.updated_at = datetime.now(UTC)

                    if error_message:
                        benchmark_run.error_message = error_message

                    if (
                        status == BenchmarkStatus.IN_PROGRESS
                        and not benchmark_run.start_time
                    ):
                        benchmark_run.start_time = datetime.now(UTC)
                    elif (
                        status
                        in [BenchmarkStatus.COMPLETED, BenchmarkStatus.FAILED]
                        and not benchmark_run.end_time
                    ):
                        benchmark_run.end_time = datetime.now(UTC)

                    session.commit()

            except Exception:
                session.rollback()
                logger.exception("Error updating benchmark status")

    def get_benchmark_status(
        self, benchmark_run_id: int, username: str = None
    ) -> Optional[Dict]:
        """Get current status of a benchmark run."""
        from ...database.session_context import get_user_db_session

        with get_user_db_session(username) as session:
            try:
                benchmark_run = (
                    session.query(BenchmarkRun)
                    .filter(BenchmarkRun.id == benchmark_run_id)
                    .first()
                )
                if not benchmark_run:
                    return None

                # Calculate running accuracy from current results AND reused results from compatible runs
                # First get results specifically for this benchmark run
                current_results = (
                    session.query(BenchmarkResult)
                    .filter(
                        BenchmarkResult.benchmark_run_id == benchmark_run_id
                    )
                    .filter(BenchmarkResult.is_correct.isnot(None))
                    .all()
                )

                # Then get reused results from compatible benchmark runs (same config hash)
                # Only count results up to the number we say we've "completed"
                completed_count = benchmark_run.completed_examples or 0
                if completed_count > len(current_results):
                    # We have reused results, get them from compatible runs
                    reused_count_needed = completed_count - len(
                        current_results
                    )

                    compatible_results = (
                        session.query(BenchmarkResult)
                        .join(
                            BenchmarkRun,
                            BenchmarkResult.benchmark_run_id == BenchmarkRun.id,
                        )
                        .filter(
                            BenchmarkRun.config_hash
                            == benchmark_run.config_hash
                        )
                        .filter(
                            BenchmarkRun.id != benchmark_run_id
                        )  # Exclude current run
                        .filter(
                            BenchmarkRun.status == BenchmarkStatus.COMPLETED
                        )
                        .filter(BenchmarkResult.is_correct.isnot(None))
                        .order_by(BenchmarkResult.id)  # Consistent ordering
                        .limit(reused_count_needed)
                        .all()
                    )

                    # Combine current and reused results
                    results = (
                        current_results
                        + compatible_results[:reused_count_needed]
                    )
                else:
                    # No reused results, just use current results
                    results = current_results

                running_accuracy = None
                # Dynamic per-dataset accuracy tracking
                dataset_accuracies = {}

                if results:
                    # Overall running accuracy
                    correct_count = sum(1 for r in results if r.is_correct)
                    running_accuracy = (correct_count / len(results)) * 100

                    # Calculate accuracy for each dataset type dynamically
                    from collections import defaultdict

                    dataset_results = defaultdict(list)

                    # Group results by dataset type
                    for r in results:
                        dataset_results[r.dataset_type.value].append(r)

                    # Calculate accuracy for each dataset
                    for (
                        dataset_type,
                        dataset_result_list,
                    ) in dataset_results.items():
                        if dataset_result_list:
                            correct = sum(
                                1 for r in dataset_result_list if r.is_correct
                            )
                            accuracy = (
                                correct / len(dataset_result_list)
                            ) * 100
                            # Store with _accuracy suffix for consistency
                            dataset_accuracies[f"{dataset_type}_accuracy"] = (
                                accuracy
                            )

                # Calculate time estimates and reliability metrics
                estimated_time_remaining = None
                total_elapsed_time = None
                avg_time_per_example = None
                accuracy_confidence = None

                # Get ALL results for timing calculation (including those pending evaluation)
                all_results_for_timing = (
                    session.query(BenchmarkResult)
                    .filter(
                        BenchmarkResult.benchmark_run_id == benchmark_run_id
                    )
                    .all()
                )

                if benchmark_run.start_time and all_results_for_timing:
                    # Calculate elapsed time
                    current_time = datetime.now(UTC)
                    total_elapsed_time = (
                        current_time - benchmark_run.start_time
                    ).total_seconds()

                    # Calculate average processing time per example using actual count
                    avg_time_per_example = total_elapsed_time / len(
                        all_results_for_timing
                    )

                    logger.info(
                        f"Time calculation - elapsed: {total_elapsed_time:.2f}s, "
                        f"results_count: {len(all_results_for_timing)}, "
                        f"avg_per_example: {avg_time_per_example:.2f}s"
                    )

                    # Estimate remaining time
                    remaining_examples = benchmark_run.total_examples - len(
                        all_results_for_timing
                    )
                    if remaining_examples > 0:
                        estimated_time_remaining = (
                            avg_time_per_example * remaining_examples
                        )
                        logger.info(
                            f"Time estimation - total: {benchmark_run.total_examples}, "
                            f"completed: {len(all_results_for_timing)}, remaining: {remaining_examples}, "
                            f"avg_time: {avg_time_per_example:.2f}s, "
                            f"estimated_remaining: {estimated_time_remaining:.2f}s"
                        )

                # Calculate accuracy confidence interval (95% confidence)
                if results and len(results) >= 3:
                    import math

                    n = len(results)
                    p = running_accuracy / 100 if running_accuracy else 0
                    # Standard error for proportion
                    se = math.sqrt(p * (1 - p) / n)
                    # 95% confidence interval (±1.96 * SE)
                    margin_of_error = 1.96 * se * 100
                    accuracy_confidence = {
                        "lower_bound": max(
                            0, running_accuracy - margin_of_error
                        ),
                        "upper_bound": min(
                            100, running_accuracy + margin_of_error
                        ),
                        "margin_of_error": margin_of_error,
                        "sample_size": n,
                    }
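                    # Worked example (hypothetical numbers): n=25 results at
                    # 60% running accuracy give SE = sqrt(0.6 * 0.4 / 25)
                    # ~= 0.098, a margin of error of ~19.2 points, and a 95%
                    # interval of roughly [40.8, 79.2].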

                status_data = {
                    "id": benchmark_run.id,
                    "run_name": benchmark_run.run_name,
                    "status": benchmark_run.status.value,
                    "completed_examples": len(
                        all_results_for_timing
                    ),  # Use actual count from DB
                    "total_examples": benchmark_run.total_examples,
                    "failed_examples": benchmark_run.failed_examples,
                    "overall_accuracy": benchmark_run.overall_accuracy
                    or running_accuracy,  # Use running accuracy if final not calculated
                    "running_accuracy": running_accuracy,  # Current running accuracy
                    "processing_rate": benchmark_run.processing_rate,
                    "estimated_time_remaining": estimated_time_remaining,  # seconds
                    "total_elapsed_time": total_elapsed_time,  # seconds
                    "avg_time_per_example": avg_time_per_example,  # seconds
                    "accuracy_confidence": accuracy_confidence,  # confidence interval
                    "created_at": benchmark_run.created_at.isoformat()
                    if benchmark_run.created_at
                    else None,
                    "start_time": benchmark_run.start_time.isoformat()
                    if benchmark_run.start_time
                    else None,
                    "end_time": benchmark_run.end_time.isoformat()
                    if benchmark_run.end_time
                    else None,
                    "error_message": benchmark_run.error_message,
                    # Add all per-dataset accuracies dynamically
                    **dataset_accuracies,
                }

                logger.info(
                    f"Benchmark {benchmark_run_id} status - completed: {benchmark_run.completed_examples}, "
                    f"running_acc: {running_accuracy}, dataset_accuracies: {dataset_accuracies}, "
                    f"avg_time: {avg_time_per_example}"
                )

                return status_data

            except Exception:
                logger.exception("Error getting benchmark status")
                return None

    def cancel_benchmark(
        self, benchmark_run_id: int, username: str = None
    ) -> bool:
        """Cancel a running benchmark."""
        try:
            if benchmark_run_id in self.active_runs:
                self.active_runs[benchmark_run_id]["status"] = "cancelled"

            self.update_benchmark_status(
                benchmark_run_id, BenchmarkStatus.CANCELLED, username=username
            )
            logger.info(f"Cancelled benchmark run {benchmark_run_id}")
            return True

        except Exception:
            logger.exception(f"Error cancelling benchmark {benchmark_run_id}")
            return False


# Global service instance
benchmark_service = BenchmarkService()
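
# Illustrative usage sketch (the run name, dataset counts, and credentials
# below are assumptions, not a fixed API):
#
#     run_id = benchmark_service.create_benchmark_run(
#         run_name="nightly-simpleqa",
#         search_config={"iterations": 3, "search_tool": "searxng"},
#         evaluation_config={},
#         datasets_config={"simpleqa": {"count": 10}},
#         username="alice",
#         user_password="...",
#     )
#     benchmark_service.start_benchmark(run_id, username="alice")
#     benchmark_service.get_benchmark_status(run_id, username="alice")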
